{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "77cbc007",
   "metadata": {},
   "source": [
    "# ONNX Runtime部署-三十类水果-预测单张图像\n",
    "\n",
    "使用推理引擎 ONNX Runtime，读取 onnx 格式的模型文件，对单张图像文件进行预测。\n",
    "\n",
    "同济子豪兄 https://space.bilibili.com/1900783\n",
    "\n",
    "2022-8-22"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01ec88b8",
   "metadata": {},
   "source": [
    "## 导入工具包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bd1f5efe",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.7/dist-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: \n",
      "  warn(f\"Failed to load image Python extension: {e}\")\n"
     ]
    }
   ],
   "source": [
    "import onnxruntime\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torchvision import transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c74d0dbb",
   "metadata": {},
   "source": [
    "## 载入 onnx 模型，获取 ONNX Runtime 推理器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4de991a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name the execution providers explicitly: onnxruntime >= 1.9 raises a ValueError\n",
    "# when `providers` is omitted on builds that ship more than the CPU provider.\n",
    "# get_available_providers() lists what this install supports, in priority order.\n",
    "ort_session = onnxruntime.InferenceSession(\n",
    "    'resnet18_fruit30.onnx',\n",
    "    providers=onnxruntime.get_available_providers(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "355e364d",
   "metadata": {},
   "source": [
    "## 构造输入，获取输出结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "004e44a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random dummy batch (1 image, 3 channels, 256x256) used to smoke-test the session\n",
    "x = torch.randn((1, 3, 256, 256)).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6fd6d88a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 3, 256, 256)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5cdaf58e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ONNX Runtime feed: graph input name -> numpy array\n",
    "ort_inputs = dict(input=x)\n",
    "\n",
    "# Fetch only the 'output' node and keep the first (only) returned array\n",
    "ort_output = ort_session.run(['output'], ort_inputs)[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a353fa4d",
   "metadata": {},
   "source": [
    "注意：输入输出张量的名称（此处为 `input` 和 `output`）需要与 torch.onnx.export 中通过 input_names、output_names 设置的名称保持一致。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "da82ec21",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 30)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ort_output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "53e6c863",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ort_output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7a3d43c",
   "metadata": {},
   "source": [
    "## 预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d2a211ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-time preprocessing: resize + center-crop to 256x256, convert to tensor, normalize\n",
    "test_transform = transforms.Compose([\n",
    "    transforms.Resize(256),\n",
    "    transforms.CenterCrop(256),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(\n",
    "        mean=[0.485, 0.456, 0.406],\n",
    "        std=[0.229, 0.224, 0.225],\n",
    "    ),\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01a55477",
   "metadata": {},
   "source": [
    "### 载入测试图像"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3651ecdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "img_path = 'test_bananan.jpg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "82477a4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load with Pillow and force 3-channel RGB: grayscale, palette or RGBA files\n",
    "# would otherwise not match the 3-channel Normalize in test_transform.\n",
    "img_pil = Image.open(img_path).convert('RGB')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4a9f873b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# img_pil"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f669dbad",
   "metadata": {},
   "source": [
    "## 运行预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a50f48fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_img = test_transform(img_pil)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a3d2e65c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 256, 256])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_img.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4b6bb4fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add a leading batch axis and hand the data to ONNX Runtime as a numpy array\n",
    "input_tensor = input_img.unsqueeze(dim=0).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "95a356dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 3, 256, 256)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_tensor.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "525f9a98",
   "metadata": {},
   "source": [
    "## ONNX Runtime预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "0d22cc6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ONNX Runtime feed for the preprocessed image batch\n",
    "ort_inputs = dict(input=input_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "265e5fea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run ONNX Runtime and wrap the returned logits array in a torch tensor\n",
    "pred_logits = torch.tensor(ort_session.run(['output'], ort_inputs)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "6392ab62",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 30])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "100f78f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply softmax over dim 1 to turn the logit scores into probabilities\n",
    "pred_softmax = torch.softmax(pred_logits, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "17410759",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 30])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_softmax.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7678855e",
   "metadata": {},
   "source": [
    "## 解析预测结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "b715fce9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 取置信度最大的 n 个结果\n",
    "n = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "5c8c74c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (values, indices) of the n largest probabilities\n",
    "top_n = pred_softmax.topk(n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "de08cc60",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.return_types.topk(\n",
       "values=tensor([[9.9972e-01, 1.1216e-04, 5.0886e-05]]),\n",
       "indices=tensor([[28,  8, 29]]))"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "top_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "4ce13a12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class indices of the top-n predictions for the first (only) row of the batch\n",
    "pred_ids = top_n.indices[0].numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "ee4c6c2c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([28,  8, 29])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "d6794361",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confidence (softmax probability) of each top-n prediction\n",
    "confs = top_n.values[0].numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "ff6d67c1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([9.9972135e-01, 1.1216471e-04, 5.0885516e-05], dtype=float32)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "confs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5268a2c",
   "metadata": {},
   "source": [
    "## 打印预测结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "02e0b5ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the class-index -> label mapping stored as a pickled object in a .npy file\n",
    "# NOTE(security): allow_pickle=True unpickles the file and can execute arbitrary\n",
    "# code, so only load an idx_to_labels.npy that comes from a trusted source.\n",
    "idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "bea9e718",
   "metadata": {},
   "outputs": [],
   "source": [
    "# idx_to_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "8afe3489",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "香蕉     99.972\n",
      "椰子     0.011\n",
      "黄瓜     0.005\n"
     ]
    }
   ],
   "source": [
    "# One line per top-n prediction: class name and confidence in percent\n",
    "for class_id, prob in zip(pred_ids, confs):\n",
    "    class_name = idx_to_labels[class_id]\n",
    "    confidence = prob * 100\n",
    "    text = '{:<6} {:>.3f}'.format(class_name, confidence)\n",
    "    print(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5e2af8b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
